{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.data.all import *\n",
    "from fastai.text.core import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp text.models.awdlstm\n",
    "#| default_cls_lvl 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AWD-LSTM\n",
    "\n",
    "> AWD LSTM from [Smerity et al.](https://arxiv.org/pdf/1708.02182.pdf) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic NLP modules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On top of the PyTorch and fastai [`layers`](/layers.html#layers), the language models use some custom layers specific to NLP."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def dropout_mask(\n",
    "    x:Tensor, # Source tensor, output will be of the same type as `x` \n",
    "    sz:list, # Size of the dropout mask as `int`s\n",
    "    p:float # Dropout probability\n",
    ") -> Tensor: # Multiplicative dropout mask\n",
    "    \"Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to cancel an element.\"\n",
    "    return x.new_empty(*sz).bernoulli_(1-p).div_(1-p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The mask takes its size from `sz`, not from the source tensor.\n",
    "t = dropout_mask(torch.randn(3,4), [4,3], 0.25)\n",
    "test_eq(t.shape, [4,3])\n",
    "# With p=0.25, every element is either dropped (0) or kept and rescaled by 1/(1-0.25) = 4/3.\n",
    "assert ((t == 4/3) + (t==0)).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RNNDropout(Module):\n",
    "    \"Dropout with probability `p` that is consistent on the seq_len dimension.\"\n",
    "    def __init__(self,\n",
    "        p:float=0.5 # Dropout probability\n",
    "    ):\n",
    "        self.p = p\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"Multiply `x` by a dropout mask shared along dim 1; `x` is returned untouched in eval mode or when `p` is 0.\"\n",
    "        if self.training and self.p != 0.:\n",
    "            # Size 1 on dim 1 makes the same mask broadcast across that whole dimension.\n",
    "            mask_sz = (x.size(0), 1, *x.shape[2:])\n",
    "            return x * dropout_mask(x.data, mask_sz, self.p)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dp = RNNDropout(0.3)\n",
    "tst_inp = torch.randn(4,3,7)\n",
    "tst_out = dp(tst_inp)\n",
    "# The mask is shared along dim 1: for each (i,j), either the whole slice `[i,:,j]` is zeroed\n",
    "# or the whole slice is kept and rescaled by 1/(1-0.3).\n",
    "for i in range(4):\n",
    "    for j in range(7):\n",
    "        if tst_out[i,0,j] == 0: assert (tst_out[i,:,j] == 0).all()\n",
    "        else: test_close(tst_out[i,:,j], tst_inp[i,:,j]/(1-0.3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It also supports doing dropout over a sequence of images, where the time dimension is axis 1: for instance, sequences of 10 images of 3 channels and 32 by 32."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5-d input: the mask size is built from `x.shape[2:]`, so any number of trailing dimensions is accepted.\n",
    "_ = dp(torch.rand(4,10,3,32,32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class WeightDropout(Module):\n",
    "    \"A module that wraps another layer in which some weights will be replaced by 0 during training.\"\n",
    "\n",
    "    def __init__(self,\n",
    "        module:nn.Module, # Wrapped module\n",
    "        weight_p:float, # Weight dropout probability\n",
    "        layer_names:str|MutableSequence='weight_hh_l0' # Name(s) of the parameters to apply dropout to\n",
    "    ):\n",
    "        self.module = module\n",
    "        self.weight_p = weight_p\n",
    "        self.layer_names = L(layer_names)\n",
    "        for name in self.layer_names:\n",
    "            # Move the selected parameter from the wrapped module onto this wrapper as `<name>_raw`...\n",
    "            orig_w = getattr(self.module, name)\n",
    "            delattr(self.module, name)\n",
    "            self.register_parameter(f'{name}_raw', nn.Parameter(orig_w.data))\n",
    "            # ...and leave a clone on the wrapped module under the original name.\n",
    "            setattr(self.module, name, orig_w.clone())\n",
    "            if isinstance(self.module, (nn.RNNBase, nn.modules.rnn.RNNBase)):\n",
    "                # Replace `flatten_parameters` on RNN modules with a no-op.\n",
    "                self.module.flatten_parameters = self._do_nothing\n",
    "\n",
    "    def _setweights(self):\n",
    "        \"Apply dropout to the raw weights.\"\n",
    "        for name in self.layer_names:\n",
    "            raw_w = getattr(self, f'{name}_raw')\n",
    "            # Dropout only while training; otherwise hand the wrapped module an unmodified clone.\n",
    "            new_w = F.dropout(raw_w, p=self.weight_p) if self.training else raw_w.clone()\n",
    "            setattr(self.module, name, new_w)\n",
    "\n",
    "    def forward(self, *args):\n",
    "        \"Refresh the weights of the wrapped module, then call it with `args`.\"\n",
    "        self._setweights()\n",
    "        with warnings.catch_warnings():\n",
    "            # To avoid the warning that comes because the weights aren't flattened.\n",
    "            warnings.simplefilter(\"ignore\", category=UserWarning)\n",
    "            out = self.module(*args)\n",
    "        return out\n",
    "\n",
    "    def reset(self):\n",
    "        \"Put clones of the raw weights back on the wrapped module and call its `reset` if it has one.\"\n",
    "        for name in self.layer_names:\n",
    "            setattr(self.module, name, getattr(self, f'{name}_raw').clone())\n",
    "        if hasattr(self.module, 'reset'): self.module.reset()\n",
    "\n",
    "    def _do_nothing(self):\n",
    "        \"No-op used in place of `flatten_parameters`.\"\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module = nn.LSTM(5,7)\n",
    "dp_module = WeightDropout(module, 0.4)\n",
    "# Tensor left on the wrapped LSTM by `__init__` (a clone of the original weights).\n",
    "wgts = dp_module.module.weight_hh_l0\n",
    "tst_inp = torch.randn(10,20,5)\n",
    "h = torch.zeros(1,20,7), torch.zeros(1,20,7)\n",
    "dp_module.reset()\n",
    "x,h = dp_module(tst_inp,h)\n",
    "loss = x.sum()\n",
    "loss.backward()\n",
    "# After a forward pass, the wrapped LSTM holds the dropped-out version of the raw weights.\n",
    "new_wgts = getattr(dp_module.module, 'weight_hh_l0')\n",
    "test_eq(wgts, getattr(dp_module, 'weight_hh_l0_raw'))\n",
    "# With weight_p=0.4, the fraction of zeroed weights should fall in a loose band around 0.4.\n",
    "assert 0.2 <= (new_wgts==0).sum().float()/new_wgts.numel() <= 0.6\n",
    "# Gradients must reach the raw parameter that lives on the wrapper.\n",
    "assert dp_module.weight_hh_l0_raw.requires_grad\n",
    "assert dp_module.weight_hh_l0_raw.grad is not None\n",
    "# At least one dropped weight also has a zero gradient.\n",
    "assert ((dp_module.weight_hh_l0_raw.grad == 0.) & (new_wgts == 0.)).any()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class EmbeddingDropout(Module):\n",
    "    \"Apply dropout with probability `embed_p` to an embedding layer `emb`.\"\n",
    "\n",
    "    def __init__(self,\n",
    "        emb:nn.Embedding, # Wrapped embedding layer\n",
    "        embed_p:float # Embedding layer dropout probability\n",
    "    ):\n",
    "        self.emb,self.embed_p = emb,embed_p\n",
    "\n",
    "    def forward(self, words, scale=None):\n",
    "        \"Embed `words`, dropping whole rows of the embedding matrix during training, optionally multiplied by `scale`.\"\n",
    "        if self.training and self.embed_p != 0:\n",
    "            # One mask value per row of the weight matrix: a token's embedding is kept or dropped as a whole.\n",
    "            size = (self.emb.weight.size(0),1)\n",
    "            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)\n",
    "            masked_embed = self.emb.weight * mask\n",
    "        else: masked_embed = self.emb.weight\n",
    "        # Out-of-place on purpose: when no dropout is applied `masked_embed` *is* `self.emb.weight`,\n",
    "        # so an in-place `mul_` would rescale the stored parameter on every call.\n",
    "        if scale: masked_embed = masked_embed * scale\n",
    "        # `padding_idx` is forwarded unchanged, as `nn.Embedding.forward` does: `None` means no padding row,\n",
    "        # whereas `-1` would be interpreted by `F.embedding` as the last row of the matrix.\n",
    "        return F.embedding(words, masked_embed, self.emb.padding_idx, self.emb.max_norm,\n",
    "                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = nn.Embedding(10, 7, padding_idx=1)\n",
    "enc_dp = EmbeddingDropout(enc, 0.5)\n",
    "tst_inp = torch.randint(0,10,(8,))\n",
    "tst_out = enc_dp(tst_inp)\n",
    "# With p=0.5, each looked-up row is either entirely zero or the original embedding scaled by 1/(1-0.5) = 2.\n",
    "for i in range(8):\n",
    "    assert (tst_out[i]==0).all() or torch.allclose(tst_out[i], 2*enc.weight[tst_inp[i]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class AWD_LSTM(Module):\n",
    "    \"AWD-LSTM inspired by https://arxiv.org/abs/1708.02182\"\n",
    "    initrange=0.1\n",
    "\n",
    "    def __init__(self,\n",
    "        vocab_sz:int, # Size of the vocabulary\n",
    "        emb_sz:int, # Size of embedding vector\n",
    "        n_hid:int, # Number of features in hidden state\n",
    "        n_layers:int, # Number of LSTM layers\n",
    "        pad_token:int=1, # Padding token id\n",
    "        hidden_p:float=0.2, # Dropout probability for hidden state between layers\n",
    "        input_p:float=0.6, # Dropout probability for LSTM stack input\n",
    "        embed_p:float=0.1, # Embedding layer dropout probability\n",
    "        weight_p:float=0.5, # Hidden-to-hidden weight dropout probability for LSTM layers\n",
    "        bidir:bool=False # If set to `True` uses bidirectional LSTM layers\n",
    "    ):\n",
    "        store_attr('emb_sz,n_hid,n_layers,pad_token')\n",
    "        self.bs = 1\n",
    "        self.n_dir = 2 if bidir else 1\n",
    "        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)\n",
    "        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)\n",
    "        # The first layer reads `emb_sz` features, the following ones `n_hid`.\n",
    "        rnns = []\n",
    "        for l in range(n_layers):\n",
    "            n_in = emb_sz if l == 0 else n_hid\n",
    "            rnns.append(self._one_rnn(n_in, self._layer_nh(l), bidir, weight_p, l))\n",
    "        self.rnns = nn.ModuleList(rnns)\n",
    "        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)\n",
    "        self.input_dp = RNNDropout(input_p)\n",
    "        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for _ in range(n_layers)])\n",
    "        self.reset()\n",
    "\n",
    "    def forward(self, inp:Tensor, from_embeds:bool=False):\n",
    "        \"Run `inp` (token ids, or embeddings if `from_embeds`) through the LSTM stack and return the last layer's output.\"\n",
    "        bs,_ = inp.shape[:2] if from_embeds else inp.shape\n",
    "        if bs!=self.bs: self._change_hidden(bs)\n",
    "\n",
    "        emb = inp if from_embeds else self.encoder_dp(inp)\n",
    "        output = self.input_dp(emb)\n",
    "        new_hidden = []\n",
    "        last = self.n_layers - 1\n",
    "        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):\n",
    "            output, new_h = rnn(output, self.hidden[l])\n",
    "            new_hidden.append(new_h)\n",
    "            # No hidden dropout on the output of the last layer.\n",
    "            if l != last: output = hid_dp(output)\n",
    "        # The new states, passed through `to_detach`, are the starting states of the next call.\n",
    "        self.hidden = to_detach(new_hidden, cpu=False, gather=False)\n",
    "        return output\n",
    "\n",
    "    def _change_hidden(self, bs):\n",
    "        \"Resize the stored hidden states of every layer to batch size `bs`.\"\n",
    "        self.hidden = [self._change_one_hidden(l, bs) for l in range(self.n_layers)]\n",
    "        self.bs = bs\n",
    "\n",
    "    def _one_rnn(self, n_in, n_out, bidir, weight_p, l):\n",
    "        \"Return one of the inner rnn\"\n",
    "        lstm = nn.LSTM(n_in, n_out, 1, batch_first=True, bidirectional=bidir)\n",
    "        return WeightDropout(lstm, weight_p)\n",
    "\n",
    "    def _layer_nh(self, l):\n",
    "        \"Per-direction hidden size of layer `l`: `n_hid`, except for the last layer which uses `emb_sz`.\"\n",
    "        n_out = self.emb_sz if l == self.n_layers - 1 else self.n_hid\n",
    "        return n_out // self.n_dir\n",
    "\n",
    "    def _one_hidden(self, l):\n",
    "        \"Return one hidden state\"\n",
    "        nh = self._layer_nh(l)\n",
    "        return tuple(one_param(self).new_zeros(self.n_dir, self.bs, nh) for _ in range(2))\n",
    "\n",
    "    def _change_one_hidden(self, l, bs):\n",
    "        \"Pad with zeros, or truncate, the stored state of layer `l` along the batch dimension to match `bs`.\"\n",
    "        state = self.hidden[l]\n",
    "        if self.bs < bs:\n",
    "            n_extra,nh = bs-self.bs,self._layer_nh(l)\n",
    "            return tuple(torch.cat([h, h.new_zeros(self.n_dir, n_extra, nh)], dim=1) for h in state)\n",
    "        if self.bs > bs: return (state[0][:,:bs].contiguous(), state[1][:,:bs].contiguous())\n",
    "        return state\n",
    "\n",
    "    def reset(self):\n",
    "        \"Reset the hidden states\"\n",
    "        for r in self.rnns:\n",
    "            if hasattr(r, 'reset'): r.reset()\n",
    "        self.hidden = [self._one_hidden(l) for l in range(self.n_layers)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the core of an AWD-LSTM model, with embeddings from `vocab_sz` and `emb_sz`, `n_layers` LSTMs potentially `bidir` stacked, the first one going from `emb_sz` to `n_hid`, the last one from `n_hid` to `emb_sz` and all the inner ones from `n_hid` to `n_hid`. `pad_token` is passed to the PyTorch embedding layer. The dropouts are applied as such:\n",
    "\n",
    "- the embeddings are wrapped in `EmbeddingDropout` of probability `embed_p`;\n",
    "- the result of this embedding layer goes through an `RNNDropout` of probability `input_p`;\n",
    "- each LSTM has `WeightDropout` applied with probability `weight_p`;\n",
    "- between two of the inner LSTM, an `RNNDropout` is applied with probability `hidden_p`.\n",
    "\n",
    "The module returns the output of the last LSTM, on which the dropout of `hidden_p` is not applied; this is the output that should be fed to a decoder (in the case of a language model). The final hidden state of each LSTM is kept in `self.hidden` and used as the initial state of the next call, until `reset` is called."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vocab_sz=100, emb_sz=20, n_hid=10, n_layers=2\n",
    "tst = AWD_LSTM(100, 20, 10, 2, hidden_p=0.2, embed_p=0.02, input_p=0.1, weight_p=0.2)\n",
    "x = torch.randint(0, 100, (10,5))\n",
    "r = tst(x)\n",
    "# The stored batch size follows the input, and there is one pair of state tensors per layer.\n",
    "test_eq(tst.bs, 10)\n",
    "test_eq(len(tst.hidden), 2)\n",
    "# The first layer has `n_hid` features, the last layer maps back to `emb_sz`.\n",
    "test_eq([h_.shape for h_ in tst.hidden[0]], [[1,10,10], [1,10,10]])\n",
    "test_eq([h_.shape for h_ in tst.hidden[1]], [[1,10,20], [1,10,20]])\n",
    "\n",
    "test_eq(r.shape, [10,5,20])\n",
    "test_eq(r[:,-1], tst.hidden[-1][0][0]) #hidden state is the last timestep in raw outputs\n",
    "\n",
    "# Check that the model also runs in eval mode, after a reset and on consecutive calls.\n",
    "tst.eval()\n",
    "tst.reset()\n",
    "tst(x);\n",
    "tst(x);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#test bs change\n",
    "# The previous calls used a batch size of 10; a smaller batch makes `forward` go through `_change_hidden`.\n",
    "x = torch.randint(0, 100, (6,5))\n",
    "r = tst(x)\n",
    "test_eq(tst.bs, 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| cuda\n",
    "# Bidirectional model on the GPU; `reset` is called after the move to 'cuda' and rebuilds the hidden states.\n",
    "tst = AWD_LSTM(100, 20, 10, 2, bidir=True).to('cuda')\n",
    "tst.reset()\n",
    "x = torch.randint(0, 100, (10,5)).to('cuda')\n",
    "r = tst(x)\n",
    "\n",
    "# A second call with a smaller batch size (10 -> 6).\n",
    "x = torch.randint(0, 100, (6,5), device='cuda')\n",
    "r = tst(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def awd_lstm_lm_split(model):\n",
    "    \"Split a RNN `model` in groups for differential learning rates.\"\n",
    "    body = model[0]\n",
    "    # One group per (rnn, hidden dropout) pair...\n",
    "    groups = []\n",
    "    for rnn, dp in zip(body.rnns, body.hidden_dps): groups.append(nn.Sequential(rnn, dp))\n",
    "    # ...then a last group with the embeddings, their dropout and `model[1]`.\n",
    "    groups.append(nn.Sequential(body.encoder, body.encoder_dp, model[1]))\n",
    "    return L(groups).map(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "awd_lstm_lm_config = dict(emb_sz=400, n_hid=1152, n_layers=3, pad_token=1, bidir=False, output_p=0.1,\n",
    "                          hidden_p=0.15, input_p=0.25, embed_p=0.02, weight_p=0.2, tie_weights=True, out_bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def awd_lstm_clas_split(model):\n",
    "    \"Split a RNN `model` in groups for differential learning rates.\"\n",
    "    body = model[0].module\n",
    "    # First the embeddings with their dropout, then one group per (rnn, hidden dropout) pair...\n",
    "    groups = [nn.Sequential(body.encoder, body.encoder_dp)]\n",
    "    for rnn, dp in zip(body.rnns, body.hidden_dps): groups.append(nn.Sequential(rnn, dp))\n",
    "    # ...and finally `model[1]` on its own.\n",
    "    groups.append(model[1])\n",
    "    return L(groups).map(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "awd_lstm_clas_config = dict(emb_sz=400, n_hid=1152, n_layers=3, pad_token=1, bidir=False, output_p=0.4,\n",
    "                            hidden_p=0.3, input_p=0.4, embed_p=0.05, weight_p=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
